archived/visualization/utils.py (290 lines of code) (raw):
import re
import warnings
from enum import Enum
import numpy as np
import pandas as pd
from IPython.display import HTML, display
from matplotlib import pyplot as plt
# "seaborn-muted" was removed in matplotlib 3.6 (renamed "seaborn-v0_8-muted");
# try the historical name first so behavior is unchanged on older matplotlib,
# and fall back so the module still imports on newer releases.
try:
    plt.style.use("seaborn-muted")
except OSError:
    plt.style.use("seaborn-v0_8-muted")
##### TABLE
def group_by_feature(baseline_statistics, latest_statistics, violations):
    """Merge baseline stats, latest stats, and violations into one per-feature dict.

    Returns a dict keyed by feature name; each value may hold the keys
    "baseline", "latest", and "violations" (a list), depending on which
    of the three inputs mention that feature. Any input may be falsy.
    """
    merged = {}
    if baseline_statistics:
        for feat in baseline_statistics["features"]:
            merged.setdefault(feat["name"], {})["baseline"] = feat
    if latest_statistics:
        for feat in latest_statistics["features"]:
            merged.setdefault(feat["name"], {})["latest"] = feat
    if violations:
        for violation in violations:
            entry = merged.setdefault(violation["feature_name"], {})
            entry.setdefault("violations", []).append(violation)
    return merged
def violation_exists(feature, check_type):
    """Return True iff *feature* carries a violation of *check_type*."""
    return any(
        violation["constraint_check_type"] == check_type
        for violation in feature.get("violations", [])
    )
def create_data_type_df(feature_names, features):
    """Build the "data_type" column and a parallel boolean violation-flag frame."""
    columns = ["data_type"]
    values = [[features[name]["latest"]["inferred_type"]] for name in feature_names]
    flags = [[violation_exists(features[name], "data_type_check")] for name in feature_names]
    df = pd.DataFrame(values, index=feature_names, columns=columns)
    df_style = pd.DataFrame(flags, index=feature_names, columns=columns)
    return df, df_style
def get_completeness(feature):
    """Return the fraction of present (non-missing) values for a feature.

    Args:
        feature: a single feature-statistics dict with an "inferred_type"
            key and the matching "numerical_statistics"/"string_statistics"
            sub-dict containing a "common" section.

    Returns:
        num_present / (num_present + num_missing), or ``np.nan`` when the
        feature has no observations at all (previously this raised
        ZeroDivisionError).

    Raises:
        ValueError: if the inferred type is not Fractional/Integral/String.
    """
    inferred_type = feature["inferred_type"]
    if inferred_type in ("Fractional", "Integral"):
        common = feature["numerical_statistics"]["common"]
    elif inferred_type == "String":
        common = feature["string_statistics"]["common"]
    else:
        raise ValueError("Unknown `inferred_type` {}.".format(inferred_type))
    num_present = common["num_present"]
    num_missing = common["num_missing"]
    total = num_present + num_missing
    if total == 0:
        # No observations recorded: completeness is undefined, not an error.
        return np.nan
    return num_present / total
def create_completeness_df(feature_names, features):
    """Build the "completeness" column and a parallel boolean violation-flag frame."""
    columns = ["completeness"]
    values, flags = [], []
    for name in feature_names:
        feature = features[name]
        values.append([get_completeness(feature["latest"])])
        flags.append([violation_exists(feature, "completeness_check")])
    df = pd.DataFrame(values, index=feature_names, columns=columns)
    df_style = pd.DataFrame(flags, index=feature_names, columns=columns)
    return df, df_style
def get_baseline_drift(feature):
    """Parse the drift distance out of a baseline_drift_check violation.

    Returns ``np.nan`` when the feature has no such violation or the
    description does not contain a parsable distance.
    """
    for violation in feature.get("violations", []):
        if violation["constraint_check_type"] != "baseline_drift_check":
            continue
        match = re.search("distance: (.+) exceeds", violation["description"])
        if match:
            return float(match.group(1))
    return np.nan
def create_baseline_drift_df(feature_names, features):
    """Build the "baseline_drift" column and a parallel boolean violation-flag frame."""
    columns = ["baseline_drift"]
    values, flags = [], []
    for name in feature_names:
        feature = features[name]
        values.append([get_baseline_drift(feature)])
        flags.append([violation_exists(feature, "baseline_drift_check")])
    df = pd.DataFrame(values, index=feature_names, columns=columns)
    df_style = pd.DataFrame(flags, index=feature_names, columns=columns)
    return df, df_style
def get_categorical_values(feature):
    """Parse the offending value fraction out of a categorical_values_check violation.

    Returns ``np.nan`` when the feature has no such violation or the
    description does not contain a parsable value.
    """
    for violation in feature.get("violations", []):
        if violation["constraint_check_type"] != "categorical_values_check":
            continue
        match = re.search(
            "Value: (.+) does not meet the constraint requirement!",
            violation["description"],
        )
        if match:
            return float(match.group(1))
    return np.nan
def create_categorical_values_df(feature_names, features):
    """Build the "categorical_values" column and a parallel boolean violation-flag frame."""
    columns = ["categorical_values"]
    values, flags = [], []
    for name in feature_names:
        feature = features[name]
        values.append([get_categorical_values(feature)])
        flags.append([violation_exists(feature, "categorical_values_check")])
    df = pd.DataFrame(values, index=feature_names, columns=columns)
    df_style = pd.DataFrame(flags, index=feature_names, columns=columns)
    return df, df_style
def create_violation_df(baseline_statistics, latest_statistics, violations):
    """Assemble the full per-feature violation table and its style-flag table.

    Returns (df, df_style): *df* holds the values for the four checks,
    *df_style* the matching boolean violation flags, both indexed by
    sorted feature name.
    """
    features = group_by_feature(baseline_statistics, latest_statistics, violations)
    feature_names = sorted(features)
    column_builders = [
        create_data_type_df,
        create_completeness_df,
        create_baseline_drift_df,
        create_categorical_values_df,
    ]
    pairs = [builder(feature_names, features) for builder in column_builders]
    df = pd.concat([value_df for value_df, _ in pairs], axis=1)
    df_style = pd.concat([style_df for _, style_df in pairs], axis=1)
    return df, df_style
def style_violation_df(df, df_style):
    """Render *df* as a styled table, highlighting violations from *df_style*.

    Args:
        df: violation table (columns: data_type, completeness,
            baseline_drift, categorical_values), indexed by feature name.
            NOTE: mutated in place — the three percentage columns are
            reformatted as strings.
        df_style: same-shaped boolean frame; True marks a violated check.

    Returns:
        a pandas Styler: rows containing any violation get a pale yellow
        background, and the specific violated cells are orange.
    """
    def all_white(df):
        # Base layer: paint every cell white so the later layers fully
        # determine the final colors.
        attr = "background-color: white"
        return pd.DataFrame(attr, index=df.index, columns=df.columns)
    def highlight_failed_row(df):
        nonlocal df_style
        df_style_cp = df_style.copy()
        # Broadcast "any violation in this row" across all of the row's columns.
        values = df_style_cp.values.any(axis=1, keepdims=True) * np.ones_like(df_style)
        df_style_cp = pd.DataFrame(values, index=df.index, columns=df.columns)
        df_style_cp = df_style_cp.replace(to_replace=True, value="background-color: #fff7dc")
        df_style_cp = df_style_cp.replace(to_replace=False, value="")
        return df_style_cp
    def highlight_failed(df):
        nonlocal df_style
        # Orange for the individual cells whose check failed.
        df_style_cp = df_style.copy()
        df_style_cp = df_style_cp.replace(to_replace=True, value="background-color: orange")
        df_style_cp = df_style_cp.replace(to_replace=False, value="")
        return df_style_cp
    def style_percentage(value):
        # NaN means the check produced no measurement for this feature.
        if np.isnan(value):
            return "N/A"
        else:
            return "{:.2%}".format(value)
    # Only the three numeric columns are percentages; data_type stays as-is.
    for column_name in ["completeness", "baseline_drift", "categorical_values"]:
        df[column_name] = df[column_name].apply(style_percentage)
    # Layer the three styling passes; later passes override earlier ones.
    return (
        df.style.apply(all_white, axis=None)
        .apply(highlight_failed_row, axis=None)
        .apply(highlight_failed, axis=None)
    )
def show_violation_df(baseline_statistics, latest_statistics, violations):
    """Build the violation table and return it styled for notebook display."""
    df, df_style = create_violation_df(baseline_statistics, latest_statistics, violations)
    return style_violation_df(df, df_style)
##### VISUALIZATION
def get_features(raw_data):
    """Index the statistics document's feature list by feature name."""
    by_name = {}
    for feature in raw_data["features"]:
        by_name[feature["name"]] = feature
    return by_name
def show_distributions(features, baselines=None):
    """Display summary tables and distribution charts, split by feature type.

    String features and numerical features each get an HTML heading, a
    summary-statistics table, and a grid of distribution charts; baseline
    distributions are overlaid when *baselines* is given.
    """
    string_features = [
        name
        for name, feature in features.items()
        if FeatureType(feature["inferred_type"]) == FeatureType.STRING
    ]
    numerical_features = [name for name in features if name not in string_features]

    def summary_table(names):
        # One summary row per feature, or None when there are none.
        if not names:
            return None
        return pd.concat([_summary_stats(features[name]) for name in names], axis=0)

    numerical_table = summary_table(numerical_features)
    string_table = summary_table(string_features)
    if numerical_features:
        display(HTML("<h3>{msg}</h3>".format(msg="Numerical Features")))
        display(numerical_table)
        _display_charts(_get_charts(features, numerical_features, baselines))
    if string_features:
        display(HTML("<h3>{msg}</h3>".format(msg="String Features")))
        display(string_table)
        _display_charts(_get_charts(features, string_features, baselines), numerical=False)
def _display_charts(chart_tables, ncols=5, numerical=True):
    """Plot one bar chart per feature distribution on an *ncols*-wide grid.

    Args:
        chart_tables: list of 1- or 2-tuples of DataFrames; element 0 is the
            collected distribution, optional element 1 the baseline.
        ncols: number of chart columns per row of the figure.
        numerical: True for KLL bucket charts (lower_bound/upper_bound
            columns), False for categorical frequency charts (value column).
    """
    nrows = int(np.ceil(len(chart_tables) / ncols))
    fig, ax = plt.subplots(nrows, ncols, figsize=(20, 4 * nrows))
    for i, chart_table in enumerate(chart_tables):
        # BUG FIX: grid position must be derived from ncols, not a
        # hard-coded 5 — any ncols != 5 previously scattered charts into
        # the wrong (or out-of-range) cells.
        row, col = i // ncols, i % ncols
        curr_ax = ax[row][col] if nrows > 1 else ax[col]
        opacity = 0.7
        if numerical:
            c = chart_table[0].sort_values(by=["lower_bound"])
            # Bucket width taken from the first bucket; bars centered on buckets.
            c_width = c.upper_bound.values[0] - c.lower_bound.values[0]
            pos_c = 0.5 * (c.upper_bound.values + c.lower_bound.values)
        else:
            # Show at most the 10 most frequent categories.
            c = (
                chart_table[0].sort_values(by=["frequency"], ascending=False).iloc[:10]
                if len(chart_table[0]) > 10
                else chart_table[0].sort_values(by=["frequency"], ascending=False)
            )
            c_width = 0.35
            pos_c = np.arange(len(c.value.values))
        curr_ax.bar(pos_c, c.frequency, c_width, label="collected", alpha=opacity)
        if len(chart_table) > 1:  # also includes baseline stats info
            if numerical:
                b = chart_table[1].sort_values(by=["lower_bound"])
                b_width = b.upper_bound.values[0] - b.lower_bound.values[0]
                pos_b = 0.5 * (b.upper_bound.values + b.lower_bound.values)
                curr_ax.bar(pos_b, b.frequency, b_width, label="baseline", alpha=opacity)
            else:
                # Left-join on category so baseline bars align with the
                # collected categories shown above.
                b = c.merge(chart_table[1], how="left", on=["value"])
                b_width = 0.35
                pos_b = np.arange(len(b.value.values)) + b_width
                curr_ax.bar(pos_b, b.frequency_y, b_width, label="baseline", alpha=opacity)
        curr_ax.legend()
        if not numerical:
            curr_ax.set_xticks(pos_c + c_width / 2)
            # Truncate long category labels so ticks stay readable.
            curr_ax.set_xticklabels(
                [label[:10] if len(label) > 10 else label for label in c.value.values],
            )
            for tick in curr_ax.get_xticklabels():
                tick.set_rotation(90)
                tick.set_fontsize(8)
        curr_ax.set_xlabel(c.key.values[0])
    plt.ylabel("Frequency")
    # Hide the unused trailing axes when the grid isn't completely filled.
    if ncols * nrows != len(chart_tables):
        for unused_ax in ax.flat[-(ncols * nrows - len(chart_tables)):]:
            unused_ax.set_visible(False)
    plt.show()
def _get_charts(features, feature_types, baselines=None):
    """Pair each feature's distribution with its baseline distribution.

    Returns a list of tuples — (collected,) or (collected, baseline) —
    skipping features whose collected distribution is empty.
    """
    charts = []
    for name in feature_types:
        if baselines is not None:
            charts.append((_extract_dist(features[name]), _extract_dist(baselines[name])))
        else:
            charts.append((_extract_dist(features[name]),))
    return [chart for chart in charts if not chart[0].empty]
def _extract_dist(feature_dict):
    """Return the feature's distribution buckets as a DataFrame.

    Adds a normalized "frequency" column and a "key" column carrying the
    feature name. Returns an empty DataFrame when any expected key is
    absent (features with no recorded distribution).
    """
    try:
        is_string = FeatureType(feature_dict["inferred_type"]) == FeatureType.STRING
        stats_key = "string_statistics" if is_string else "numerical_statistics"
        distribution_type = "categorical" if is_string else "kll"
        buckets = feature_dict[stats_key]["distribution"][distribution_type]["buckets"]
        table = pd.DataFrame(buckets)
        table["frequency"] = table["count"] / table["count"].sum()
        table["key"] = [feature_dict["name"]] * len(table)
    except KeyError:
        table = pd.DataFrame()
    return table
def _summary_stats(feature_dict):
    """Return a one-row DataFrame of common + type-specific summary statistics."""
    if FeatureType(feature_dict["inferred_type"]) == FeatureType.STRING:
        stats_key = "string_statistics"
    else:
        stats_key = "numerical_statistics"
    index = [feature_dict["name"]]
    common = pd.DataFrame(feature_dict[stats_key]["common"], index=index)
    # Everything except the nested "common" and "distribution" sections.
    type_specific = {
        key: value
        for key, value in feature_dict[stats_key].items()
        if key != "common" and key != "distribution"
    }
    specific = pd.DataFrame(type_specific, index=index)
    return pd.concat([common, specific], axis=1)
class FeatureType(Enum):
    """Inferred data type of a monitored feature.

    Values match the "inferred_type" strings found in the statistics JSON.
    """
    INTEGRAL = "Integral"
    FRACTIONAL = "Fractional"
    STRING = "String"
    UNKNOWN = "Unknown"